function [output] = CCE_main(input_array)
% CCE_MAIN  Hypothesis test with clustered standard errors, choosing the
% clustering from a sequence of candidate clustering structures.
%
% For every candidate clustering (one column of group_matrix_km) the
% p-value threshold is adjusted via rej_rate_CCE_mean, the average
% rejection rate over alt_vec_tradeoff is evaluated at that threshold, and
% the clustering with the highest average is selected (size-power
% tradeoff criterion). Estimation and inference are then carried out with
% the selected clustering ("star") and, for comparison, with every fixed
% clustering.
%
% INPUT (fields of input_array read by this function)
%   method           - 'ols' or 'iv' (case-insensitive)
%   data             - struct with fields D, X, Y (and Z when method is 'iv')
%   Sigma_hat        - passed through unchanged to rej_rate_CCE_mean
%   group_matrix_km  - matrix whose columns are cluster membership vectors
%   alpha_sig        - nominal significance level
%   alt_vec          - alternatives added to the estimate for p_alt
%   alt_vec_tradeoff - alternatives used for the size-power tradeoff
%
% OUTPUT (fields of output)
%   estimate            - first coefficient of the OLS/IV fit
%   p_adj               - adjusted p-value threshold of the selected clustering
%   p_null              - p-value of the t-test of the coefficient against 0
%   p_alt               - p-values with alt_vec added to the estimate
%   G_star              - number of clusters in the selected clustering
%   ave_power           - average rejection rate of the selected clustering
%   p_fixed_G           - p-value against 0 for each candidate clustering
%   p_threshold_fixed_G - adjusted threshold for each candidate clustering
%   estimated_size      - rejection rate under the null (capped at alpha_sig)
%   p_tradeoff          - p-values with alt_vec_tradeoff added to the estimate
%   alt_power_curve     - grid of 50 alternatives for the power curve
%   p_power_curve       - p-values on that grid
%
% Raises an error if method is neither 'ols' nor 'iv'.


%% DATA INPUT
method = input_array.method;
data = input_array.data;
Sigma_hat = input_array.Sigma_hat;
group_matrix = input_array.group_matrix_km;
alpha_sig = input_array.alpha_sig;
alt_vec = input_array.alt_vec;
alt_vec_tradeoff = input_array.alt_vec_tradeoff;

% Two-sided p-value of a t statistic. The lower-tail form 2*tcdf(-|t|,df)
% equals 2*(1-tcdf(|t|,df)) by symmetry but does not cancel to exactly 0
% for large |t|.
two_sided_p = @(t,df) 2*tcdf(-abs(t),df);


%% adjust p-value thresholds to control sizes for each clustering
n_clusterings = size(group_matrix,2);
G_vec = zeros(1,n_clusterings); % number of clusters in each clustering
for i = 1 : n_clusterings
    G_vec(i) = numel(unique(group_matrix(:,i)));
end

ave_power_vec = zeros(1,n_clusterings);
p_adj_vec = zeros(1,n_clusterings);
rej_rate_vec = zeros(1,n_clusterings);
for jj = 1 : n_clusterings

    membership = group_matrix(:,jj);

    % rejection rate and adjusted threshold at alternative 0 (the null)
    [rej_rate,p_adj] = rej_rate_CCE_mean(Sigma_hat,membership,...
        alpha_sig,0);
    rej_rate_vec(jj) = min(rej_rate,alpha_sig);
    p_adj = min(p_adj,alpha_sig); % cannot be more aggressive

    % rejection rate at the adjusted threshold for each tradeoff alternative
    power_vec = zeros(length(alt_vec_tradeoff),1);
    for i_alt = 1 : length(alt_vec_tradeoff)
        power_vec(i_alt) = rej_rate_CCE_mean(Sigma_hat,...
            membership,p_adj,alt_vec_tradeoff(i_alt));
    end

    ave_power_vec(jj) = mean(power_vec);
    p_adj_vec(jj) = p_adj;
end


%% select the one with highest power
[ave_power,ind_G] = max(ave_power_vec);
group_star = group_matrix(:,ind_G);
p_adj = p_adj_vec(ind_G);
G_star = G_vec(ind_G);
estimated_size = rej_rate_vec(ind_G);


%% point estimate and residuals
% W_mat and iW are the first and third arguments handed to cluster_se;
% they are stored once so that the selected clustering and every fixed
% clustering use the same call.
switch lower(method)
    case 'ols'

        X_mat = [data.D data.X];
        Y = data.Y;

        coef = X_mat\Y;
        resid = Y-X_mat*coef;

        W_mat = X_mat;
        iW = inv(X_mat'*X_mat);

        alt_power_curve = linspace(alt_vec(1),alt_vec(end),50);

    case 'iv'

        X_mat = [data.D data.X];
        Z_mat = [data.Z data.X];
        Y = data.Y;

        coef = (Z_mat'*X_mat)\(Z_mat'*Y);
        resid = Y-X_mat*coef;

        W_mat = Z_mat;
        iW = inv(Z_mat'*X_mat);

        % the IV power curve drops the two outermost alternatives
        alt_power_curve = linspace(alt_vec(2),alt_vec(end-1),50);

    otherwise
        error('CCE_main:unknownMethod',...
            'Unknown method "%s"; expected ''ols'' or ''iv''.',method);
end
beta0_hat = coef(1); % coefficient on D


%% clustered s.e. using the selected clustering
% Kept in its own variable: the fixed-G loop below computes a different
% standard error for every candidate clustering and must not overwrite it.
sctmp = cluster_se(W_mat,resid,iW,group_star);
se_star = sctmp(1);
% degrees of freedom = number of clusters - 1; equals max(group_star)-1
% whenever clusters are labelled 1..G
df_star = G_star-1;


%% OUTPUT for estimation and inference table
% p-value under the null
p_null = two_sided_p((beta0_hat-0)/se_star,df_star);

% p-value under the alternatives
p_alt = two_sided_p((beta0_hat+alt_vec-0)/se_star,df_star);

output.estimate = beta0_hat;
output.p_adj = p_adj;
output.p_null = p_null;
output.p_alt = p_alt;


%% p-value with fixed G
p_fixed_G = zeros(1,n_clusterings);
for i_g = 1 : n_clusterings
    sctmp = cluster_se(W_mat,resid,iW,group_matrix(:,i_g));
    se_fixed = sctmp(1);
    p_fixed_G(i_g) = two_sided_p((beta0_hat-0)/se_fixed,G_vec(i_g)-1);
end


%% OUTPUT for clustering table
output.G_star = G_star;
output.ave_power = ave_power;
output.p_fixed_G = p_fixed_G;
output.p_threshold_fixed_G = p_adj_vec;
output.estimated_size = estimated_size;
% p-value used for comparing actual power and estimated power
% (selected clustering: se_star together with df_star)
output.p_tradeoff = two_sided_p(...
    (beta0_hat+alt_vec_tradeoff-0)/se_star,df_star);


%% OUTPUT for power curve
output.alt_power_curve = alt_power_curve;
output.p_power_curve = two_sided_p(...
    (beta0_hat+alt_power_curve-0)/se_star,df_star);








